{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DP Tensor Experiment - v5\n",
    "\n",
    "### Purpose:\n",
    "Vectorize v4 - will reference v3\n",
    "\n",
    "### Conclusions:\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch as th"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 307,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# large sentinel used to mask out slots of entities that are not in a\n",
    "# tensor's ancestry before taking an element-wise max / min\n",
    "max_long = 2**62\n",
    "\n",
    "class PrivateTensor():\n",
    "    \"\"\"Tensor that carries per-entity upper / lower bounds next to its values.\n",
    "\n",
    "    `values` holds the data. `max_vals` / `min_vals` hold one row per element\n",
    "    and one column per entity (the reductions below run over dim 1), giving\n",
    "    the largest / smallest value recorded for that element and entity.\n",
    "    `entities` is a float 0/1 mask of the entities in the ancestry of each\n",
    "    element; when omitted it is inferred as the slots where the bounds differ.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, values, max_vals, min_vals, entities=None):\n",
    "\n",
    "        self.values = values\n",
    "        self.max_vals = max_vals\n",
    "        self.min_vals = min_vals\n",
    "\n",
    "        # an entity only matters where its two bounds are different\n",
    "        if entities is None:\n",
    "            entities = self.max_vals != self.min_vals\n",
    "\n",
    "        # one hot encoding of entities in the ancestry of this tensor\n",
    "        self.entities = entities\n",
    "\n",
    "    def __add__(self, other):\n",
    "        \"\"\"Add a PrivateTensor or a public number and return a new PrivateTensor.\"\"\"\n",
    "\n",
    "        if isinstance(other, PrivateTensor):\n",
    "            # only slots that belong to an operand's entities contribute bounds\n",
    "            new_vals = self.values + other.values\n",
    "            new_max_vals = (self.max_vals * self.entities) + (other.max_vals * other.entities)\n",
    "            new_min_vals = (self.min_vals * self.entities) + (other.min_vals * other.entities)\n",
    "        else:\n",
    "            # a public number shifts the values and both bounds equally\n",
    "            new_vals = self.values + other\n",
    "            new_max_vals = self.max_vals + other\n",
    "            new_min_vals = self.min_vals + other\n",
    "\n",
    "        return PrivateTensor(new_vals, new_max_vals, new_min_vals)\n",
    "\n",
    "    def __sub__(self, other):\n",
    "        \"\"\"Subtract a PrivateTensor or a public number and return a new PrivateTensor.\"\"\"\n",
    "\n",
    "        if isinstance(other, PrivateTensor):\n",
    "            new_vals = self.values - other.values\n",
    "\n",
    "            # other's max/min are swapped on purpose: this is equivalent to\n",
    "            # self + (other * -1), and multiplying by a negative number swaps\n",
    "            # the bounds with each other and flips their sign\n",
    "            new_max_vals = (self.entities * self.max_vals) - (other.entities * other.min_vals)\n",
    "            new_min_vals = (self.entities * self.min_vals) - (other.entities * other.max_vals)\n",
    "        else:\n",
    "            # subtract a public number\n",
    "            new_vals = self.values - other\n",
    "            new_max_vals = self.max_vals - other\n",
    "            new_min_vals = self.min_vals - other\n",
    "\n",
    "        return PrivateTensor(new_vals, new_max_vals, new_min_vals)\n",
    "\n",
    "    def __mul__(self, other):\n",
    "        \"\"\"Multiply by a PrivateTensor or a public number.\n",
    "\n",
    "        For a public number the comparison `other > 0` is evaluated directly,\n",
    "        so `other` is expected to be a scalar in that case.\n",
    "        \"\"\"\n",
    "\n",
    "        if isinstance(other, PrivateTensor):\n",
    "            new_vals = self.values * other.values\n",
    "\n",
    "            self_max, self_min = self._product_bounds(other)\n",
    "            other_max, other_min = other._product_bounds(self)\n",
    "\n",
    "            new_max_vals, new_min_vals = self._merge_bounds(other,\n",
    "                                                            self_max, self_min,\n",
    "                                                            other_max, other_min)\n",
    "        else:\n",
    "            # multiply by a public number\n",
    "            new_vals = self.values * other\n",
    "\n",
    "            if other > 0:\n",
    "                new_max_vals = self.max_vals * other\n",
    "                new_min_vals = self.min_vals * other\n",
    "            else:\n",
    "                # a non-positive factor swaps the roles of the two bounds\n",
    "                new_min_vals = self.max_vals * other\n",
    "                new_max_vals = self.min_vals * other\n",
    "\n",
    "        return PrivateTensor(new_vals, new_max_vals, new_min_vals)\n",
    "\n",
    "    def __neg__(self):\n",
    "        \"\"\"Negate the values; the bounds swap roles and flip sign.\"\"\"\n",
    "\n",
    "        # note that min_vals and max_vals are reversed intentionally\n",
    "        return PrivateTensor(-self.values, -self.min_vals, -self.max_vals)\n",
    "\n",
    "    def __truediv__(self, other):\n",
    "        \"\"\"Divide by a public number.\n",
    "\n",
    "        Raises:\n",
    "            Exception: if `other` is a PrivateTensor.\n",
    "        \"\"\"\n",
    "\n",
    "        if isinstance(other, PrivateTensor):\n",
    "            raise Exception(\"probably best not to do this - it's gonna be inf a lot\")\n",
    "\n",
    "        new_vals = self.values / other\n",
    "        new_max_vals = self.max_vals / other\n",
    "        new_min_vals = self.min_vals / other\n",
    "\n",
    "        # a negative divisor reverses the order of the bounds, as in __mul__\n",
    "        if other < 0:\n",
    "            new_max_vals, new_min_vals = new_min_vals, new_max_vals\n",
    "\n",
    "        return PrivateTensor(new_vals, new_max_vals, new_min_vals)\n",
    "\n",
    "    def __gt__(self, other):\n",
    "        \"\"\"Element-wise `self > other` with 0/1 bounds on the outcome.\"\"\"\n",
    "\n",
    "        if isinstance(other, PrivateTensor):\n",
    "            new_vals = self.values > other.values\n",
    "        else:\n",
    "            new_vals = self.values > other\n",
    "\n",
    "        new_max_val, new_min_val = self._comparison_bounds(other)\n",
    "\n",
    "        return PrivateTensor(new_vals, new_max_val, new_min_val)\n",
    "\n",
    "    def __lt__(self, other):\n",
    "        \"\"\"Element-wise `self < other` with 0/1 bounds on the outcome.\"\"\"\n",
    "\n",
    "        if isinstance(other, PrivateTensor):\n",
    "            result = self.values < other.values\n",
    "        else:\n",
    "            result = self.values < other\n",
    "\n",
    "        new_max_val, new_min_val = self._comparison_bounds(other)\n",
    "\n",
    "        return PrivateTensor(result, new_max_val, new_min_val)\n",
    "\n",
    "    def _product_bounds(self, other):\n",
    "        \"\"\"Return (upper, lower) bounds self's entities contribute to self * other.\"\"\"\n",
    "\n",
    "        upper = th.max(self.min_vals * other.expanded_minminvals,\n",
    "                       self.max_vals * other.expanded_maxmaxvals)\n",
    "\n",
    "        lower = th.min(self.min_vals * other.expanded_maxmaxvals,\n",
    "                       self.max_vals * other.expanded_minminvals)\n",
    "\n",
    "        return upper, lower\n",
    "\n",
    "    def _overlap_bounds(self, other):\n",
    "        \"\"\"Return 0/1 (upper, lower) bounds for comparing self against other.\n",
    "\n",
    "        The upper bound is 1 where the range of one of self's entities overlaps\n",
    "        the overall range of `other`, i.e. where the entity could change the\n",
    "        outcome. NOTE: other's overall range is taken over entity-masked\n",
    "        bounds, so masked-out slots contribute 0 to it.\n",
    "        \"\"\"\n",
    "\n",
    "        # self is entirely above / entirely below every possible other\n",
    "        above = (self.min_vals > other.expanded_maxmaxvals).float() * self.entities\n",
    "        below = (self.max_vals < other.expanded_minminvals).float() * self.entities\n",
    "\n",
    "        # never exceeds 1 as long as other's max is >= other's min\n",
    "        disjoint = above + below\n",
    "\n",
    "        # a threshold output can't be less than 0\n",
    "        return 1 - disjoint, disjoint * 0\n",
    "\n",
    "    def _merge_bounds(self, other, self_max, self_min, other_max, other_min):\n",
    "        \"\"\"Combine bounds contributed by self's and other's entities.\n",
    "\n",
    "        Each operand's bounds are kept only for its own entities; the rest are\n",
    "        replaced by -/+ max_long so they lose the element-wise max / min.\n",
    "        Slots that belong to neither operand end up as 0.\n",
    "        \"\"\"\n",
    "\n",
    "        entities_self_or_other = ((self.entities + other.entities) > 0).float()\n",
    "\n",
    "        self_max = (self_max * self.entities) + ((1 - self.entities) * -max_long)\n",
    "        other_max = (other_max * other.entities) + ((1 - other.entities) * -max_long)\n",
    "\n",
    "        self_min = (self_min * self.entities) + ((1 - self.entities) * max_long)\n",
    "        other_min = (other_min * other.entities) + ((1 - other.entities) * max_long)\n",
    "\n",
    "        new_max = th.max(self_max, other_max) * entities_self_or_other\n",
    "        new_min = th.min(self_min, other_min) * entities_self_or_other\n",
    "\n",
    "        return new_max, new_min\n",
    "\n",
    "    def _comparison_bounds(self, other):\n",
    "        \"\"\"Return (max, min) bounds shared by __gt__ and __lt__.\"\"\"\n",
    "\n",
    "        if isinstance(other, PrivateTensor):\n",
    "            self_max, self_min = self._overlap_bounds(other)\n",
    "            other_max, other_min = other._overlap_bounds(self)\n",
    "            return self._merge_bounds(other, self_max, self_min, other_max, other_min)\n",
    "\n",
    "        # public threshold: an entity can flip the result only when the\n",
    "        # threshold lies inside its [min, max] range\n",
    "        inside = ((other <= self.max_vals) * (other >= self.min_vals)).float()\n",
    "\n",
    "        # slots outside the ancestry must not become entities just because\n",
    "        # the threshold equals their (identical) bounds\n",
    "        new_max = inside * self.entities\n",
    "\n",
    "        return new_max, new_max * 0\n",
    "\n",
    "    def clamp_min(self, other):\n",
    "        \"\"\"Clamp values from below by a public number.\n",
    "\n",
    "        Both bounds are clamped so that min_vals can never exceed max_vals.\n",
    "\n",
    "        Raises:\n",
    "            Exception: if `other` is a PrivateTensor.\n",
    "        \"\"\"\n",
    "\n",
    "        if isinstance(other, PrivateTensor):\n",
    "            raise Exception(\"Not implemented yet\")\n",
    "\n",
    "        return PrivateTensor(self.values.clamp_min(other),\n",
    "                             self.max_vals.clamp_min(other),\n",
    "                             self.min_vals.clamp_min(other))\n",
    "\n",
    "    def clamp_max(self, other):\n",
    "        \"\"\"Clamp values from above by a public number.\n",
    "\n",
    "        Both bounds are clamped so that min_vals can never exceed max_vals.\n",
    "\n",
    "        Raises:\n",
    "            Exception: if `other` is a PrivateTensor.\n",
    "        \"\"\"\n",
    "\n",
    "        if isinstance(other, PrivateTensor):\n",
    "            raise Exception(\"Not implemented yet\")\n",
    "\n",
    "        return PrivateTensor(self.values.clamp_max(other),\n",
    "                             self.max_vals.clamp_max(other),\n",
    "                             self.min_vals.clamp_max(other))\n",
    "\n",
    "    @property\n",
    "    def maxmaxvals(self):\n",
    "        \"\"\"This returns the maximum possible value over all entities\"\"\"\n",
    "\n",
    "        return (self.max_vals * self.entities).max(1)[0]\n",
    "\n",
    "    @property\n",
    "    def expanded_maxmaxvals(self):\n",
    "        \"\"\"maxmaxvals repeated along dim 1 to the shape of max_vals\"\"\"\n",
    "\n",
    "        return self.maxmaxvals.unsqueeze(1).expand(self.max_vals.shape)\n",
    "\n",
    "    @property\n",
    "    def minminvals(self):\n",
    "        \"\"\"This returns the minimum possible values over all entities\"\"\"\n",
    "\n",
    "        return (self.min_vals * self.entities).min(1)[0]\n",
    "\n",
    "    @property\n",
    "    def expanded_minminvals(self):\n",
    "        \"\"\"minminvals repeated along dim 1 to the shape of min_vals\"\"\"\n",
    "\n",
    "        return self.minminvals.unsqueeze(1).expand(self.min_vals.shape)\n",
    "\n",
    "    @property\n",
    "    def sensitivity(self):\n",
    "        \"\"\"Largest (max - min) gap over dim 1 for every element\"\"\"\n",
    "\n",
    "        return (self.max_vals - self.min_vals).max(1)[0]\n",
    "\n",
    "    @property\n",
    "    def entities(self):\n",
    "        return self._entities\n",
    "\n",
    "    @entities.setter\n",
    "    def entities(self, x):\n",
    "        # stored as float so the mask can be used directly in arithmetic\n",
    "        self._entities = x.float()\n",
    "\n",
    "    def hard_sigmoid(self):\n",
    "        \"\"\"Clamp the tensor into the range [0, 1].\"\"\"\n",
    "\n",
    "        return self.clamp_max(1).clamp_min(0)\n",
    "\n",
    "    def hard_sigmoid_deriv(self, leak=0.01):\n",
    "        \"\"\"Derivative of hard_sigmoid, with slope +/- `leak` outside (0, 1).\"\"\"\n",
    "\n",
    "        return ((self < 1) * (self > 0)) + (self < 0) * leak - (self > 1) * leak"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 308,
   "metadata": {},
   "outputs": [],
   "source": [
    "# six elements, each with bounds for five entity slots\n",
    "data = th.tensor([2., 0., 1., 0., 0., 0.])\n",
    "max_vals = th.tensor([\n",
    "    [3., 1., 0., 0., 0.],\n",
    "    [0., 1., 0., 0., 0.],\n",
    "    [0., 0., 1., 0., 0.],\n",
    "    [0., 0., 0., 0., 0.],\n",
    "    [0., 0., 0., 0., 1.],\n",
    "    [0., 0., 0., 0., 1.],\n",
    "])\n",
    "min_vals = th.tensor([\n",
    "    [2., 0., 0., 0., 0.],\n",
    "    [0., 0., 0., 0., 0.],\n",
    "    [0., 0., 0., 0., 0.],\n",
    "    [0., 0., 0., 0., 0.],\n",
    "    [0., 0., 0., 0., 0.],\n",
    "    [0., 0., 0., 0., 0.],\n",
    "])\n",
    "x = PrivateTensor(data, max_vals, min_vals)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 309,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = x.clamp_min(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 310,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2., 1., 1., 1., 1., 1.])"
      ]
     },
     "execution_count": 310,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y.values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 311,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 1.,  0.,  0., -1.,  0.,  0.])"
      ]
     },
     "execution_count": 311,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y.sensitivity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 314,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1., 0., 1., 0., 0., 0.])"
      ]
     },
     "execution_count": 314,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.values.clamp_max(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 315,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2., 1., 1., 1., 1., 1.])"
      ]
     },
     "execution_count": 315,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y.values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "import numpy as np\n",
    "\n",
    "class PrivateNumber():\n",
    "    \"\"\"Scalar whose upper / lower bounds are tracked per entity.\n",
    "\n",
    "    `max_val` and `min_val` are dicts mapping an entity key to the largest /\n",
    "    smallest value recorded for that entity. The keys of `max_val` define the\n",
    "    entity set (see `entities`).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, value, max_val, min_val):\n",
    "        self.value = value\n",
    "        self.max_val = max_val\n",
    "        self.min_val = min_val\n",
    "\n",
    "    def __add__(self, other):\n",
    "        \"\"\"Add a PrivateNumber or a public number and return a new PrivateNumber.\"\"\"\n",
    "\n",
    "        if isinstance(other, PrivateNumber):\n",
    "\n",
    "            new_val = self.value + other.value\n",
    "\n",
    "            # an entity missing from one operand contributes 0 on that side\n",
    "            new_max_val = {}\n",
    "            new_min_val = {}\n",
    "            for entity in self.entities | other.entities:\n",
    "                new_max_val[entity] = self.max_val.get(entity, 0) + other.max_val.get(entity, 0)\n",
    "                new_min_val[entity] = self.min_val.get(entity, 0) + other.min_val.get(entity, 0)\n",
    "\n",
    "        else:\n",
    "\n",
    "            # a public number shifts the value and both bounds equally\n",
    "            new_val = self.value + other\n",
    "            new_max_val = {entity: self.max_val[entity] + other for entity in self.entities}\n",
    "            new_min_val = {entity: self.min_val[entity] + other for entity in self.entities}\n",
    "\n",
    "        return PrivateNumber(new_val, new_max_val, new_min_val)\n",
    "\n",
    "    def __mul__(self, other):\n",
    "        \"\"\"Multiply by a PrivateNumber or a public number.\"\"\"\n",
    "\n",
    "        if isinstance(other, PrivateNumber):\n",
    "\n",
    "            self_max, self_min = self._product_bounds(other)\n",
    "            other_max, other_min = other._product_bounds(self)\n",
    "\n",
    "            # merge over the union so entities found only in other are kept\n",
    "            new_max_val, new_min_val = self._merge_bounds(self_max, self_min,\n",
    "                                                          other_max, other_min)\n",
    "\n",
    "            return PrivateNumber(self.value * other.value,\n",
    "                                 new_max_val,\n",
    "                                 new_min_val)\n",
    "\n",
    "        new_max_val = {entity: self.max_val[entity] * other for entity in self.entities}\n",
    "        new_min_val = {entity: self.min_val[entity] * other for entity in self.entities}\n",
    "\n",
    "        if other > 0:\n",
    "            return PrivateNumber(self.value * other, new_max_val, new_min_val)\n",
    "\n",
    "        # a non-positive factor swaps the roles of the two bounds\n",
    "        return PrivateNumber(self.value * other, new_min_val, new_max_val)\n",
    "\n",
    "    def __neg__(self):\n",
    "        return self * -1\n",
    "\n",
    "    def __sub__(self, other):\n",
    "        return self + (-other)\n",
    "\n",
    "    def __gt__(self, other):\n",
    "        \"\"\"Return `self > other` as a 0/1 PrivateNumber.\"\"\"\n",
    "\n",
    "        if isinstance(other, PrivateNumber):\n",
    "            result = int(self.value > other.value)\n",
    "        else:\n",
    "            result = int(self.value > other)\n",
    "\n",
    "        new_max_val, new_min_val = self._comparison_bounds(other)\n",
    "\n",
    "        return PrivateNumber(result, new_max_val, new_min_val)\n",
    "\n",
    "    def __lt__(self, other):\n",
    "        \"\"\"Return `self < other` as a 0/1 PrivateNumber.\"\"\"\n",
    "\n",
    "        if isinstance(other, PrivateNumber):\n",
    "            result = int(self.value < other.value)\n",
    "        else:\n",
    "            result = int(self.value < other)\n",
    "\n",
    "        new_max_val, new_min_val = self._comparison_bounds(other)\n",
    "\n",
    "        return PrivateNumber(result, new_max_val, new_min_val)\n",
    "\n",
    "    def _product_bounds(self, other):\n",
    "        \"\"\"Return (upper, lower) dicts self's entities contribute to self * other.\"\"\"\n",
    "\n",
    "        other_min, other_max = other.xmin, other.xmax\n",
    "\n",
    "        upper = {}\n",
    "        lower = {}\n",
    "        for entity in self.entities:\n",
    "            low, high = self.min_val[entity], self.max_val[entity]\n",
    "\n",
    "            # the biggest positive number this entity could contribute is when\n",
    "            # it is multiplied by the largest value of the same sign from other\n",
    "            upper[entity] = max(low * other_min, high * other_max)\n",
    "\n",
    "            # the smallest negative number this entity could contribute is when\n",
    "            # it is multiplied by the largest value of the opposite sign from other\n",
    "            lower[entity] = min(low * other_max, high * other_min)\n",
    "\n",
    "        return upper, lower\n",
    "\n",
    "    def _overlap_bounds(self, other):\n",
    "        \"\"\"Return 0/1 (upper, lower) dicts for comparing self against other.\n",
    "\n",
    "        The upper bound is 1 when the entity's range overlaps other's overall\n",
    "        range [xmin, xmax], i.e. when the entity could change the outcome.\n",
    "        \"\"\"\n",
    "\n",
    "        other_min, other_max = other.xmin, other.xmax\n",
    "\n",
    "        upper = {}\n",
    "        lower = {}\n",
    "        for entity in self.entities:\n",
    "            disjoint = self.min_val[entity] > other_max or self.max_val[entity] < other_min\n",
    "            upper[entity] = 0 if disjoint else 1\n",
    "            lower[entity] = 0\n",
    "\n",
    "        return upper, lower\n",
    "\n",
    "    @staticmethod\n",
    "    def _merge_bounds(left_max, left_min, right_max, right_min):\n",
    "        \"\"\"Merge two sets of per-entity bounds over the union of their entities.\n",
    "\n",
    "        An entity present on both sides gets the larger max and the smaller\n",
    "        min; an entity present on one side keeps that side's bounds.\n",
    "        \"\"\"\n",
    "\n",
    "        new_max_val = {}\n",
    "        new_min_val = {}\n",
    "        for entity in set(left_max) | set(right_max):\n",
    "            new_max_val[entity] = max(b[entity] for b in (left_max, right_max) if entity in b)\n",
    "            new_min_val[entity] = min(b[entity] for b in (left_min, right_min) if entity in b)\n",
    "\n",
    "        return new_max_val, new_min_val\n",
    "\n",
    "    def _comparison_bounds(self, other):\n",
    "        \"\"\"Return (max_val, min_val) dicts shared by __gt__ and __lt__.\"\"\"\n",
    "\n",
    "        if isinstance(other, PrivateNumber):\n",
    "            self_max, self_min = self._overlap_bounds(other)\n",
    "            other_max, other_min = other._overlap_bounds(self)\n",
    "            return self._merge_bounds(self_max, self_min, other_max, other_min)\n",
    "\n",
    "        # public threshold: a public number has no entities of its own, and an\n",
    "        # entity can flip the result only when the threshold lies in its range\n",
    "        new_max_val = {}\n",
    "        new_min_val = {}\n",
    "        for entity in self.entities:\n",
    "            inside = self.min_val[entity] <= other <= self.max_val[entity]\n",
    "            new_max_val[entity] = 1 if inside else 0\n",
    "            new_min_val[entity] = 0\n",
    "\n",
    "        return new_max_val, new_min_val\n",
    "\n",
    "    def max(self, other):\n",
    "        \"\"\"Clamp from below by a public number.\n",
    "\n",
    "        Both bounds are clamped so that min_val can never exceed max_val.\n",
    "\n",
    "        Raises:\n",
    "            Exception: if `other` is a PrivateNumber.\n",
    "        \"\"\"\n",
    "\n",
    "        if isinstance(other, PrivateNumber):\n",
    "            raise Exception(\"Not implemented yet\")\n",
    "\n",
    "        new_max_val = {entity: max(self.max_val[entity], other) for entity in self.entities}\n",
    "        new_min_val = {entity: max(self.min_val[entity], other) for entity in self.entities}\n",
    "\n",
    "        return PrivateNumber(max(self.value, other), new_max_val, new_min_val)\n",
    "\n",
    "    def min(self, other):\n",
    "        \"\"\"Clamp from above by a public number.\n",
    "\n",
    "        Both bounds are clamped so that min_val can never exceed max_val.\n",
    "\n",
    "        Raises:\n",
    "            Exception: if `other` is a PrivateNumber.\n",
    "        \"\"\"\n",
    "\n",
    "        if isinstance(other, PrivateNumber):\n",
    "            raise Exception(\"Not implemented yet\")\n",
    "\n",
    "        new_max_val = {entity: min(self.max_val[entity], other) for entity in self.entities}\n",
    "        new_min_val = {entity: min(self.min_val[entity], other) for entity in self.entities}\n",
    "\n",
    "        return PrivateNumber(min(self.value, other), new_max_val, new_min_val)\n",
    "\n",
    "    def __repr__(self):\n",
    "        return str(self.value) + \" \" + str(self.max_val) + \" \" + str(self.min_val)\n",
    "\n",
    "    def hard_sigmoid(self):\n",
    "        \"\"\"Clamp the number into the range [0, 1].\"\"\"\n",
    "        return self.min(1).max(0)\n",
    "\n",
    "    def hard_sigmoid_deriv(self):\n",
    "        \"\"\"Derivative of hard_sigmoid, with slope +/- 0.01 outside (0, 1).\"\"\"\n",
    "        return ((self < 1) * (self > 0)) + (self < 0) * 0.01 - (self > 1) * 0.01\n",
    "\n",
    "    @property\n",
    "    def xmin(self):\n",
    "        \"\"\"Smallest lower bound over all entities\"\"\"\n",
    "        return min(self.min_val.values())\n",
    "\n",
    "    @property\n",
    "    def xmax(self):\n",
    "        \"\"\"Largest upper bound over all entities\"\"\"\n",
    "        return max(self.max_val.values())\n",
    "\n",
    "    @property\n",
    "    def entities(self):\n",
    "        \"\"\"Set of entity keys, taken from max_val\"\"\"\n",
    "        return set(self.max_val.keys())\n",
    "\n",
    "    @property\n",
    "    def sensitivity(self):\n",
    "        \"\"\"Largest (max - min) gap over all entities\"\"\"\n",
    "        return max(upper - self.min_val[entity] for entity, upper in self.max_val.items())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example inputs: value, per-entity max bounds, per-entity min bounds.\n",
    "x = PrivateNumber(\n",
    "    0.5,\n",
    "    {\"bob\": 4, \"amos\": 3},\n",
    "    {\"bob\": 3, \"amos\": 2},\n",
    ")\n",
    "y = PrivateNumber(1, {\"bob\": 1}, {\"bob\": -1})\n",
    "z = PrivateNumber(-0.5, {\"sue\": 2}, {\"sue\": -1})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare two private numbers; `a` is a PrivateNumber holding the 0/1 outcome.\n",
    "a = y < z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 122,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Largest (max bound - min bound) gap over the entities of the comparison result.\n",
    "a.sensitivity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "import numpy as np\n",
    "class PrivateNumber():\n",
    "    \n",
    "    def __init__(self, value, max_val, min_val):\n",
    "        self.value = value\n",
    "        self.max_val = max_val\n",
    "        self.min_val = min_val\n",
    "        \n",
    "    def __add__(self, other):\n",
    "        \n",
    "        # add to a private number\n",
    "        \n",
    "        if(isinstance(other, PrivateNumber)):\n",
    "\n",
    "            entities = self.entities.union(other.entities)\n",
    "            \n",
    "            new_val = self.value + other.value\n",
    "\n",
    "            entities = set(self.max_val.keys()).union(set(other.max_val.keys()))\n",
    "\n",
    "            new_max_val = Counter()\n",
    "            new_min_val = Counter()            \n",
    "            for entity in entities:\n",
    "                new_max_val[entity] = self.max_val[entity] + other.max_val[entity]\n",
    "                new_min_val[entity] = self.min_val[entity] + other.min_val[entity]\n",
    "\n",
    "            return PrivateNumber(self.value + other.value,\n",
    "                                new_max_val,\n",
    "                                new_min_val)\n",
    "        \n",
    "        entities = self.entities\n",
    "        \n",
    "        # add to a public number\n",
    "        \n",
    "        new_max_val = Counter()\n",
    "        new_min_val = Counter()        \n",
    "        for entity in entities:\n",
    "            new_max_val[entity] = self.max_val[entity] + other\n",
    "            new_min_val[entity] = self.min_val[entity] + other\n",
    "        \n",
    "        return PrivateNumber(self.value + other,\n",
    "                                new_max_val,\n",
    "                                new_min_val)\n",
    "\n",
    "    def __mul__(self, other):\n",
    "        \n",
    "        if(isinstance(other, PrivateNumber)):\n",
    "        \n",
    "            entities = self.entities.union(other.entities)\n",
    "        \n",
    "            new_self_max_val = Counter()\n",
    "            new_self_min_val = Counter()            \n",
    "            for entity in entities:\n",
    "                \n",
    "                # the biggest positive number this entity could contribute is when\n",
    "                # it is multiplied by the largest value of the same sign from other\n",
    "                new_self_max_val[entity] = max(self.min_val[entity] * other.xmin, \n",
    "                                               self.max_val[entity] * other.xmax)\n",
    "                \n",
    "                # the smallest negative number this entity could contribute is when\n",
    "                # it is multiplied by the largest value of the opposite sign from other\n",
    "                new_self_min_val[entity] = min(self.min_val[entity] * other.xmax,\n",
    "                                               self.max_val[entity] * other.xmin)\n",
    "                \n",
    "            new_other_max_val = Counter()\n",
    "            new_other_min_val = Counter()            \n",
    "            for entity in entities:\n",
    "                \n",
    "                # the biggest positive number this entity could contribute is when\n",
    "                # it is multiplied by the largest value of the same sign from other\n",
    "                new_other_max_val[entity] = max(other.min_val[entity] * self.xmin, \n",
    "                                                other.max_val[entity] * self.xmax)\n",
    "                \n",
    "                # the smallest negative number this entity could contribute is when\n",
    "                # it is multiplied by the largest value of the opposite sign from other\n",
    "                new_other_min_val[entity] = min(other.min_val[entity] * self.xmax,\n",
    "                                                other.max_val[entity] * self.xmin)\n",
    "                \n",
    "            new_max_val = Counter()\n",
    "            new_min_val = Counter()\n",
    "            \n",
    "            for entity in entities:\n",
    "                new_max_val[entity] = max(new_self_max_val[entity], new_other_max_val[entity])\n",
    "                new_min_val[entity] = min(new_self_min_val[entity], new_other_min_val[entity])\n",
    "\n",
    "            return PrivateNumber(self.value * other.value,\n",
    "                                    new_max_val,\n",
    "                                    new_min_val)\n",
    "        \n",
    "        entities = self.entities\n",
    "        \n",
    "        new_max_val = Counter()\n",
    "        for entity in entities:\n",
    "            new_max_val[entity] = self.max_val[entity] * other\n",
    "\n",
    "        new_min_val = Counter()\n",
    "        for entity in entities:\n",
    "            new_min_val[entity] = self.min_val[entity] * other\n",
    "        \n",
    "        if(other > 0):\n",
    "            return PrivateNumber(self.value * other,\n",
    "                                    new_max_val,\n",
    "                                    new_min_val)\n",
    "        else:\n",
    "            return PrivateNumber(self.value * other,\n",
    "                                    new_min_val,                                 \n",
    "                                    new_max_val)\n",
    "    \n",
    "    def __sub__(self, other):\n",
    "        return self + (-other)\n",
    "    \n",
    "    def __mul__(self, other):\n",
    "        \n",
    "        if(isinstance(other, PrivateNumber)):\n",
    "        \n",
    "            entities = self.entities.union(other.entities)\n",
    "        \n",
    "            new_self_max_val = Counter()\n",
    "            new_self_min_val = Counter()            \n",
    "            for entity in entities:\n",
    "                \n",
    "                # the biggest positive number this entity could contribute is when\n",
    "                # it is multiplied by the largest value of the same sign from other\n",
    "                new_self_max_val[entity] = max(self.min_val[entity] * other.xmin, \n",
    "                                               self.max_val[entity] * other.xmax)\n",
    "                \n",
    "                # the smallest negative number this entity could contribute is when\n",
    "                # it is multiplied by the largest value of the opposite sign from other\n",
    "                new_self_min_val[entity] = min(self.min_val[entity] * other.xmax,\n",
    "                                               self.max_val[entity] * other.xmin)\n",
    "                \n",
    "            new_other_max_val = Counter()\n",
    "            new_other_min_val = Counter()            \n",
    "            for entity in entities:\n",
    "                \n",
    "                # the biggest positive number this entity could contribute is when\n",
    "                # it is multiplied by the largest value of the same sign from other\n",
    "                new_other_max_val[entity] = max(other.min_val[entity] * self.xmin, \n",
    "                                                other.max_val[entity] * self.xmax)\n",
    "                \n",
    "                # the smallest negative number this entity could contribute is when\n",
    "                # it is multiplied by the largest value of the opposite sign from other\n",
    "                new_other_min_val[entity] = min(other.min_val[entity] * self.xmax,\n",
    "                                                other.max_val[entity] * self.xmin)\n",
    "                \n",
    "            new_max_val = Counter()\n",
    "            new_min_val = Counter()\n",
    "            \n",
    "            for entity in entities:\n",
    "                new_max_val[entity] = max(new_self_max_val[entity], new_other_max_val[entity])\n",
    "                new_min_val[entity] = min(new_self_min_val[entity], new_other_min_val[entity])\n",
    "\n",
    "            return PrivateNumber(self.value * other.value,\n",
    "                                    new_max_val,\n",
    "                                    new_min_val)\n",
    "        \n",
    "        entities = self.entities\n",
    "        \n",
    "        new_max_val = Counter()\n",
    "        for entity in entities:\n",
    "            new_max_val[entity] = self.max_val[entity] * other\n",
    "\n",
    "        new_min_val = Counter()\n",
    "        for entity in entities:\n",
    "            new_min_val[entity] = self.min_val[entity] * other\n",
    "        \n",
    "        if(other > 0):\n",
    "            return PrivateNumber(self.value * other,\n",
    "                                    new_max_val,\n",
    "                                    new_min_val)\n",
    "        else:\n",
    "            return PrivateNumber(self.value * other,\n",
    "                                    new_min_val,                                 \n",
    "                                    new_max_val)\n",
    "    \n",
    "    def __truediv__(self, other):\n",
    "        \n",
    "        if(isinstance(other, PrivateNumber)):\n",
    "            raise Exception(\"probably best not to do this - it's gonna be inf a lot\")\n",
    "            \n",
    "        entities = self.entities\n",
    "        \n",
    "        new_max_val = Counter()\n",
    "        for entity in entities:\n",
    "            new_max_val[entity] = self.max_val[entity] / other\n",
    "\n",
    "        new_min_val = Counter()\n",
    "        for entity in entities:\n",
    "            new_min_val[entity] = self.min_val[entity] / other\n",
    "        \n",
    "        return PrivateNumber(self.value / other,\n",
    "                                new_max_val,\n",
    "                                new_min_val)\n",
    "\n",
    "    def __gt__(self, other):\n",
    "        \"\"\"BUG!: Counter() defaults to 0\"\"\"\n",
    "        if(isinstance(other, PrivateNumber)):\n",
    "        \n",
    "            entities = self.entities.union(other.entities)\n",
    "        \n",
    "            new_self_max_val = Counter()\n",
    "            new_self_min_val = Counter()            \n",
    "            for entity in entities:\n",
    "                \n",
    "                if not (self.min_val[entity] > other.xmax or self.max_val[entity] < other.xmin):\n",
    "                    new_self_max_val[entity] = 1\n",
    "                else:\n",
    "                    new_self_max_val[entity] = 0\n",
    "                \n",
    "                new_self_min_val[entity] = 0\n",
    "                \n",
    "            new_other_max_val = Counter()\n",
    "            new_other_min_val = Counter()            \n",
    "            for entity in entities:\n",
    "                \n",
    "                if not (other.min_val[entity] > self.xmax or other.max_val[entity] < self.xmin):\n",
    "                    new_other_max_val[entity] = 1\n",
    "                else:\n",
    "                    new_other_max_val[entity] = 0\n",
    "                    \n",
    "                new_other_min_val[entity] = 0\n",
    "                \n",
    "            new_max_val = Counter()\n",
    "            new_min_val = Counter()\n",
    "            \n",
    "            for entity in entities:\n",
    "                new_max_val[entity] = max(new_self_max_val[entity], new_other_max_val[entity])\n",
    "                new_min_val[entity] = min(new_self_min_val[entity], new_other_min_val[entity])\n",
    "\n",
    "            return PrivateNumber(int(self.value > other.value),\n",
    "                                    new_max_val,\n",
    "                                    new_min_val)\n",
    "        \n",
    "        entities = self.entities\n",
    "        \n",
    "        new_max_val = Counter()\n",
    "        new_min_val = Counter()\n",
    "        for entity in entities:\n",
    "            \n",
    "            new_min_val[entity] = 0\n",
    "            \n",
    "            if(other <= self.max_val[entity] and other >= self.min_val[entity]):    \n",
    "                new_max_val[entity] = 1\n",
    "            else:\n",
    "                new_max_val[entity] = 0\n",
    "\n",
    "        return PrivateNumber(int(self.value > other),\n",
    "                                new_max_val,\n",
    "                                new_min_val)\n",
    "    \n",
    "\n",
    "    def __lt__(self, other):\n",
    "        \"\"\"BUG!: Counter() defaults to 0\"\"\"\n",
    "        if(isinstance(other, PrivateNumber)):\n",
    "        \n",
    "            entities = self.entities.union(other.entities)\n",
    "        \n",
    "            new_self_max_val = Counter()\n",
    "            new_self_min_val = Counter()            \n",
    "            for entity in entities:\n",
    "                \n",
    "                if not (self.min_val[entity] > other.xmax or self.max_val[entity] < other.xmin):\n",
    "                    new_self_max_val[entity] = 1\n",
    "                else:\n",
    "                    new_self_max_val[entity] = 0\n",
    "                \n",
    "                new_self_min_val[entity] = 0\n",
    "                \n",
    "            new_other_max_val = Counter()\n",
    "            new_other_min_val = Counter()            \n",
    "            for entity in entities:\n",
    "                \n",
    "                if not (other.min_val[entity] > self.xmax or other.max_val[entity] < self.xmin):\n",
    "                    new_other_max_val[entity] = 1\n",
    "                else:\n",
    "                    new_other_max_val[entity] = 0\n",
    "                    \n",
    "                new_other_min_val[entity] = 0\n",
    "                \n",
    "            new_max_val = Counter()\n",
    "            new_min_val = Counter()\n",
    "            \n",
    "            for entity in entities:\n",
    "                new_max_val[entity] = max(new_self_max_val[entity], new_other_max_val[entity])\n",
    "                new_min_val[entity] = min(new_self_min_val[entity], new_other_min_val[entity])\n",
    "\n",
    "            return PrivateNumber(int(self.value < other.value),\n",
    "                                    new_max_val,\n",
    "                                    new_min_val)\n",
    "        \n",
    "        entities = self.entities\n",
    "        \n",
    "        new_max_val = Counter()\n",
    "        new_min_val = Counter()\n",
    "        for entity in entities:\n",
    "            \n",
    "            new_min_val[entity] = 0\n",
    "            \n",
    "            if(other <= self.max_val[entity] and other >= self.min_val[entity]):    \n",
    "                new_max_val[entity] = 1\n",
    "            else:\n",
    "                new_max_val[entity] = 0\n",
    "\n",
    "        return PrivateNumber(int(self.value < other),\n",
    "                                new_max_val,\n",
    "                                new_min_val)\n",
    "    \n",
    "    def __neg__(self):\n",
    "        return self * -1\n",
    "    \n",
    "    def max(self, other):\n",
    "        \n",
    "        if(isinstance(other, PrivateNumber)):\n",
    "            raise Exception(\"Not implemented yet\")\n",
    "        \n",
    "        entities = self.entities\n",
    "        \n",
    "        new_min_val = Counter()\n",
    "        for entity in entities:\n",
    "            new_min_val[entity] = max(self.min_val[entity], other)\n",
    "            \n",
    "        return PrivateNumber(max(self.value, other),\n",
    "                                self.max_val,\n",
    "                                new_min_val)\n",
    "    \n",
    "    def min(self, other):\n",
    "        \n",
    "        if(isinstance(other, PrivateNumber)):\n",
    "            raise Exception(\"Not implemented yet\")\n",
    "        \n",
    "        entities = self.entities\n",
    "        \n",
    "        new_max_val = Counter()\n",
    "        for entity in entities:\n",
    "            new_max_val[entity] = min(self.max_val[entity], other)\n",
    "                \n",
    "        return PrivateNumber(min(self.value, other),\n",
    "                                new_max_val,\n",
    "                                self.min_val)\n",
    "    \n",
    "    def hard_sigmoid(self):\n",
    "        return self.min(1).max(0)\n",
    "    \n",
    "    def hard_sigmoid_deriv(self):\n",
    "        return ((self < 1) * (self > 0)) + (self < 0) * 0.01 - (self > 1) * 0.01\n",
    "        \n",
    "    def __repr__(self):\n",
    "        return str(self.value) + \" \" + str(self.max_val) + \" \" + str(self.min_val)\n",
    "    \n",
    "    @property\n",
    "    def xmin(self):\n",
    "        return self.min_val.most_common(len(self.min_val))[-1][1]\n",
    "    \n",
    "    @property\n",
    "    def xmax(self):\n",
    "        return self.max_val.most_common(1)[0][1]\n",
    "    \n",
    "    @property\n",
    "    def entities(self):\n",
    "        return set(self.max_val.keys())\n",
    "    \n",
    "    @property\n",
    "    def sensitivity(self):\n",
    "        sens = Counter()\n",
    "        for entity, value in self.max_val.items():\n",
    "            sens[entity] = value - self.min_val[entity]\n",
    "        return sens.most_common()[0][1]\n",
    "    \n",
    "# Example inputs: value, per-entity max bounds, per-entity min bounds.\n",
    "x = PrivateNumber(\n",
    "    0.5,\n",
    "    Counter({\"bob\": 4, \"amos\": 3}),\n",
    "    Counter({\"bob\": 3, \"amos\": 2}),\n",
    ")\n",
    "y = PrivateNumber(1, Counter({\"bob\": 1}), Counter({\"bob\": -1}))\n",
    "z = PrivateNumber(-0.5, Counter({\"sue\": 2}), Counter({\"sue\": -1}))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare two private numbers; `a` is a PrivateNumber holding the 0/1 outcome.\n",
    "a = x > y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Largest (max bound - min bound) gap over the entities of the comparison result.\n",
    "a.sensitivity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sum of two private numbers; their bounds are added entity by entity.\n",
    "a = x + y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Product of two private numbers that have no entity in common.\n",
    "b = a * z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-1.0 Counter({'sue': 8, 'bob': 8, 'amos': 6}) Counter({'amos': -3, 'sue': -6, 'bob': -6})"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show b through __repr__: value, then max bounds, then min bounds.\n",
    "b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# class PrivacyAccountant():\n",
    "    \n",
    "#     def __init__(self, default_budget = 0.1):\n",
    "        \n",
    "#         self.entity2epsilon = {}\n",
    "#         self.entity2id = {}\n",
    "#         self.default_budget = default_budget\n",
    "        \n",
    "#     def add_entity(self, entity_id, budget=None):\n",
    "#         \"\"\"Add another entity to the system to be tracked.\n",
    "        \n",
    "#         Args:\n",
    "#             entity_id: a string or other unique identifier of the entity\n",
    "#             budget: the epsilon level defining this user's privacy budget\n",
    "#         \"\"\"\n",
    "        \n",
    "#         if(budget is None):\n",
    "#             budget = self.default_budget\n",
    "        \n",
    "#         self.entity2id[entity_id] = len(self.entity2id)\n",
    "#         self.entity2epsilon[self.entity2id[entity_id]] = budget\n",
    "        \n",
    "        \n",
    "# accountant = PrivacyAccountant()\n",
    "\n",
    "# class DPTensor():\n",
    "    \n",
    "#     def __init__(self, data, entities, max_values=None, min_values=None):\n",
    "        \n",
    "#         assert data.shape == entities.shape#[0:-1]\n",
    "\n",
    "#         self.data = data\n",
    "#         self.entities = entities\n",
    "        \n",
    "#         if max_values is None:\n",
    "#             max_values = np.inf + np.zeros_like(self.data)\n",
    "            \n",
    "#         assert max_values.shape == data.shape\n",
    "#         self.max_values = max_values    \n",
    "        \n",
    "#         if min_values is None:\n",
    "#             min_values = -np.inf + np.zeros_like(self.data)            \n",
    "            \n",
    "#         assert min_values.shape == data.shape            \n",
    "#         self.min_values = min_values\n",
    "\n",
    "#     def sum(self, dim=0):\n",
    "        \n",
    "#         _new_data = self.data.sum(dim)\n",
    "        \n",
    "#         return _new_data\n",
    "    \n",
    "#     @property\n",
    "#     def sensitivity(self):\n",
    "#         return self.max_values - self.min_values\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 163,
   "metadata": {},
   "outputs": [],
   "source": [
    "# results, tags = grid.search(\"diabetes\",\"#data\", verbose=False)\n",
    "# dataset = results['alice'][0][0:5][:,0:4]\n",
    "# n_ent = dataset.shape[0]\n",
    "# n_classes = dataset.shape[1]\n",
    "\n",
    "# for i in range(n_ent):\n",
    "#     accountant.add_entity(\"Diabetes Patient #\" + str(i))\n",
    "    \n",
    "# d2 = dataset.clone().get()\n",
    "# entities = th.arange(0,n_ent).view(-1,1).expand(n_ent,n_classes)#.unsqueeze(2)\n",
    "# db = DPTensor(data=d2, \n",
    "#               entities=entities, \n",
    "#               max_values=d2.max(0)[0].expand(n_ent,n_classes), \n",
    "#               min_values=d2.min(0)[0].expand(n_ent,n_classes))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
